import numpy as np
from scipy.stats import norm
import scipy
import torch
from diffusers.utils.torch_utils import randn_tensor


class GTWatermarkCached:
    """Fourier-domain (Tree-Ring style) watermark with a cached pattern.

    The ground-truth pattern and the mask that selects the watermarked
    frequency coefficients are moved to ``device`` once at construction and
    reused by every call. Spectra are centred with ``fftshift`` over the last
    two dimensions both when injecting and when testing.
    """

    def __init__(self, device, gt_patch, watermarking_mask):
        """Cache the watermark pattern and mask on ``device``.

        Args:
            device: Target torch device (anything accepted by ``Tensor.to``).
            gt_patch: Pattern in the centred (fftshift-ed) frequency domain;
                only entries selected by ``watermarking_mask`` are used. Must
                be complex-valued for ``tree_ring_p_value`` (it reads
                ``.imag``).
            watermarking_mask: Boolean tensor selecting the watermarked
                coefficients; it is used both as a ``torch.where`` condition
                and for boolean indexing, so it must be ``torch.bool``.
        """
        self.gt_patch = gt_patch.to(device)
        self.watermarking_mask = watermarking_mask.to(device)

    def inject_watermark(self, latents):
        """Return ``latents`` with the masked Fourier coefficients replaced.

        ``latents`` is transformed with ``fft2`` over its last two dims and
        centred with ``fftshift``. Coefficients where the mask is True take
        the matching ``gt_patch`` value and all others are kept. The spectrum
        is then transformed back. Only the real part of the inverse FFT is
        returned, so the pattern is reproduced exactly only when the modified
        spectrum is Hermitian-symmetric.

        Args:
            latents: Tensor whose last two dims are the spatial axes.

        Returns:
            Real tensor of ``latents``' shape (assuming the mask and pattern
            broadcast to it).
        """
        latents_fft = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))
        # Select rather than blend arithmetically (x * ~m + gt * m): with the
        # blend, a non-finite gt_patch entry outside the mask turns into NaN
        # (nan * 0) and the inverse FFT spreads it over the whole output.
        latents_fft = torch.where(self.watermarking_mask, self.gt_patch, latents_fft)
        return torch.fft.ifft2(torch.fft.ifftshift(latents_fft, dim=(-1, -2))).real

    # FIXME: Only keeping this here to avoid compilation issues
    def one_minus_p_value(self, latents):
        """Placeholder, not implemented.

        Presumably intended to return the probability that ``latents`` are
        watermarked (``1 - p``). TODO: confirm the intent or remove.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def tree_ring_p_value(self, latents):
        """Return the Tree-Ring detection p-value for ``latents``.

        The masked, centred Fourier coefficients of ``latents`` and of
        ``gt_patch`` are flattened into real vectors (real parts followed by
        imaginary parts). Let ``sigma`` be the sample standard deviation of
        the latents' vector. The statistic

            x = sum(((observed - target) / sigma) ** 2)

        is evaluated under a noncentral chi-squared CDF with
        ``df = target.numel()`` and ``nc = sum(target ** 2) / sigma ** 2``.
        A small value means the latents lie unusually close to the pattern,
        which is evidence of the watermark.

        Args:
            latents: Tensor whose last two dims are the spatial axes and whose
                shape is compatible with boolean indexing by the mask.

        Returns:
            The value of ``scipy.stats.ncx2.cdf`` for the statistic. If the
            masked coefficients have zero spread (``sigma == 0``), the
            intermediates are non-finite and the result is not meaningful.
        """
        mask = self.watermarking_mask

        target = self.gt_patch[mask].flatten()
        target = torch.cat([target.real, target.imag])

        observed = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))[mask].flatten()
        observed = torch.cat([observed.real, observed.imag])

        sigma = observed.std()
        noncentrality = (target ** 2 / sigma ** 2).sum().item()
        statistic = (((observed - target) / sigma) ** 2).sum().item()
        return scipy.stats.ncx2.cdf(x=statistic, df=target.numel(), nc=noncentrality)